{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Use Ragas to evaluate the OpenAI Assistant\n",
    "\n",
    "**Please note that this notebook consumes a large number of OpenAI API tokens. Read it carefully and pay attention to how many requests you send.**"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 1. Prepare environment and data\n",
    "\n",
    "Before starting, you must set OPENAI_API_KEY in your environment variables."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Install pip dependencies"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# ! python -m pip install openai beir pandas ragas==0.0.17"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Download the [Financial Opinion Mining and Question Answering (fiqa) Dataset](https://sites.google.com/view/fiqa/) if it does not already exist locally. We convert it into a Ragas-friendly format that is easier to process, following this [script](https://github.com/explodinggradients/ragas/blob/main/experiments/baselines/fiqa/dataset-exploration-and-baseline.ipynb)."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1706\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import pandas as pd\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "from datasets import Dataset\n",
    "from beir import util\n",
    "\n",
    "\n",
    "def prepare_fiqa_without_answer(knowledge_path):\n",
    "    dataset_name = \"fiqa\"\n",
    "\n",
    "    if not os.path.exists(os.path.join(knowledge_path, f'{dataset_name}.zip')):\n",
    "        url = (\n",
    "            \"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip\".format(\n",
    "                dataset_name\n",
    "            )\n",
    "        )\n",
    "        util.download_and_unzip(url, knowledge_path)\n",
    "\n",
    "    data_path = os.path.join(knowledge_path, 'fiqa')\n",
    "    with open(os.path.join(data_path, \"corpus.jsonl\")) as f:\n",
    "        cs = [pd.Series(json.loads(l)) for l in f.readlines()]\n",
    "\n",
    "    corpus_df = pd.DataFrame(cs)\n",
    "\n",
    "    corpus_df = corpus_df.rename(columns={\"_id\": \"corpus-id\", \"text\": \"ground_truth\"})\n",
    "    corpus_df = corpus_df.drop(columns=[\"title\", \"metadata\"])\n",
    "    corpus_df[\"corpus-id\"] = corpus_df[\"corpus-id\"].astype(int)\n",
    "    corpus_df.head()\n",
    "\n",
    "    with open(os.path.join(data_path, \"queries.jsonl\")) as f:\n",
    "        qs = [pd.Series(json.loads(l)) for l in f.readlines()]\n",
    "\n",
    "    queries_df = pd.DataFrame(qs)\n",
    "    queries_df = queries_df.rename(columns={\"_id\": \"query-id\", \"text\": \"question\"})\n",
    "    queries_df = queries_df.drop(columns=[\"metadata\"])\n",
    "    queries_df[\"query-id\"] = queries_df[\"query-id\"].astype(int)\n",
    "    queries_df.head()\n",
    "\n",
    "    splits = [\"dev\", \"test\", \"train\"]\n",
    "    split_df = {}\n",
    "    for s in splits:\n",
    "        split_df[s] = pd.read_csv(os.path.join(data_path, f\"qrels/{s}.tsv\"), sep=\"\\t\").drop(\n",
    "            columns=[\"score\"]\n",
    "        )\n",
    "\n",
    "    final_split_df = {}\n",
    "    for split in split_df:\n",
    "        df = queries_df.merge(split_df[split], on=\"query-id\")\n",
    "        df = df.merge(corpus_df, on=\"corpus-id\")\n",
    "        df = df.drop(columns=[\"corpus-id\"])\n",
    "        grouped = df.groupby(\"query-id\").apply(\n",
    "            lambda x: pd.Series(\n",
    "                {\n",
    "                    \"question\": x[\"question\"].sample().values[0],\n",
    "                    \"ground_truths\": x[\"ground_truth\"].tolist(),\n",
    "                }\n",
    "            )\n",
    "        )\n",
    "\n",
    "        grouped = grouped.reset_index()\n",
    "        grouped = grouped.drop(columns=\"query-id\")\n",
    "        final_split_df[split] = grouped\n",
    "\n",
    "    return final_split_df\n",
    "\n",
    "\n",
    "knowledge_datas_path = './knowledge_datas'\n",
    "fiqa_path = os.path.join(knowledge_datas_path, 'fiqa_doc.txt')\n",
    "\n",
    "# exist_ok keeps the cell re-runnable without a separate existence check.\n",
    "os.makedirs(knowledge_datas_path, exist_ok=True)\n",
    "\n",
    "# Filled by the question loop later in the notebook.\n",
    "contexts_list = []\n",
    "answer_list = []\n",
    "\n",
    "final_split_df = prepare_fiqa_without_answer(knowledge_datas_path)\n",
    "\n",
    "# The same split supplies both the knowledge documents and the questions.\n",
    "split = 'test'\n",
    "\n",
    "# Flatten the per-question passage lists into a single list of documents.\n",
    "docs = []\n",
    "for ds in final_split_df[split][\"ground_truths\"]:\n",
    "    docs.extend(ds)\n",
    "print(len(docs))\n",
    "\n",
    "# Write UTF-8 explicitly so non-ASCII passages do not depend on the platform default encoding.\n",
    "docs_str = '\\n'.join(docs)\n",
    "with open(fiqa_path, 'w', encoding='utf-8') as f:\n",
    "    f.write(docs_str)\n",
    "\n",
    "question_list = final_split_df[split][\"question\"].to_list()\n",
    "ground_truth_list = final_split_df[split][\"ground_truths\"].to_list()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Now we have the question list and the ground truth list, and the knowledge documents have been written to `fiqa_path`.\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 2. Building RAG using OpenAI assistant\n",
    "\n",
    "The helpers below ask the assistant a question and extract the retrieved context from the annotations returned by OpenAI."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import time\n",
    "from openai import OpenAI\n",
    "\n",
    "# Shared API client used by the helper functions defined in this cell.\n",
    "client = OpenAI()\n",
    "\n",
    "# The key is taken from the OPENAI_API_KEY environment variable; export it before running.\n",
    "client.api_key = os.getenv('OPENAI_API_KEY')\n",
    "\n",
    "\n",
    "class OpenAITimeoutException(Exception):\n",
    "    pass\n",
    "\n",
    "\n",
    "def get_content_from_retrieved_message(message):\n",
    "    # Extract the message content\n",
    "    message_content = message.content[0].text\n",
    "    annotations = message_content.annotations\n",
    "    contexts = []\n",
    "    for annotation in annotations:\n",
    "        message_content.value = message_content.value.replace(annotation.text, f'')\n",
    "        if (file_citation := getattr(annotation, 'file_citation', None)):\n",
    "            contexts.append(file_citation.quote)\n",
    "    if len(contexts) == 0:\n",
    "        contexts = ['empty context.']\n",
    "    return message_content.value, contexts\n",
    "\n",
    "\n",
    "def try_get_answer_contexts(assistant_id, question, timeout_seconds=120, poll_interval_seconds=1.0):\n",
    "    \"\"\"Ask the assistant one question in a fresh thread and return its contexts and answer.\n",
    "\n",
    "    Args:\n",
    "        assistant_id: Id of the assistant that should run on the thread.\n",
    "        question: User question sent as the only message of the thread.\n",
    "        timeout_seconds: Maximum time to wait for the run to complete.\n",
    "        poll_interval_seconds: Pause between two status checks of the run.\n",
    "\n",
    "    Returns:\n",
    "        tuple `(contexts, answer)` as produced by get_content_from_retrieved_message\n",
    "        (note the order: contexts first).\n",
    "\n",
    "    Raises:\n",
    "        OpenAITimeoutException: The run did not complete within `timeout_seconds`.\n",
    "        RuntimeError: The run ended as failed/cancelled/expired, or the thread holds\n",
    "            no reply besides the question.\n",
    "    \"\"\"\n",
    "    thread = client.beta.threads.create(\n",
    "        messages=[\n",
    "            {\n",
    "                \"role\": \"user\",\n",
    "                \"content\": question,\n",
    "            }\n",
    "        ]\n",
    "    )\n",
    "    thread_id = thread.id\n",
    "    try:\n",
    "        run = client.beta.threads.runs.create(\n",
    "            thread_id=thread_id,\n",
    "            assistant_id=assistant_id,\n",
    "        )\n",
    "        start_time = time.time()\n",
    "        while True:\n",
    "            run = client.beta.threads.runs.retrieve(\n",
    "                thread_id=thread_id,\n",
    "                run_id=run.id\n",
    "            )\n",
    "            if run.status == 'completed':\n",
    "                break\n",
    "            # These statuses never turn into 'completed'; stop instead of waiting for the timeout.\n",
    "            if run.status in ('failed', 'cancelled', 'expired'):\n",
    "                raise RuntimeError(f'OpenAI run ended with status {run.status!r}')\n",
    "            if time.time() - start_time > timeout_seconds:\n",
    "                # Raise the dedicated type so the caller's retry branch is actually taken.\n",
    "                raise OpenAITimeoutException(\"OpenAI retrieving answer Timeout！\")\n",
    "            # Pause between polls instead of issuing requests back to back.\n",
    "            time.sleep(poll_interval_seconds)\n",
    "        messages = client.beta.threads.messages.list(\n",
    "            thread_id=thread_id\n",
    "        )\n",
    "        # More than one message is required: the question plus at least one reply.\n",
    "        if len(messages.data) <= 1:\n",
    "            raise RuntimeError('OpenAI thread contains no assistant reply')\n",
    "        # data[0] is taken as the reply (the API is expected to list newest first - confirm).\n",
    "        res, contexts = get_content_from_retrieved_message(messages.data[0])\n",
    "    finally:\n",
    "        # Remove the thread on every path so failed or timed-out attempts do not leave threads behind.\n",
    "        client.beta.threads.delete(thread_id)\n",
    "    return contexts, res\n",
    "\n",
    "\n",
    "def get_answer_contexts_from_assistant(question, assistant_id, timeout_seconds=120, retry_num=6):\n",
    "    \"\"\"Query the assistant, retrying on timeouts, and never raise to the caller.\n",
    "\n",
    "    Args:\n",
    "        question: Question forwarded to try_get_answer_contexts.\n",
    "        assistant_id: Id of the assistant to query.\n",
    "        timeout_seconds: Per-attempt timeout forwarded to try_get_answer_contexts.\n",
    "        retry_num: Maximum number of attempts.\n",
    "\n",
    "    Returns:\n",
    "        tuple `(answer, contexts)` - the reverse of the order returned by\n",
    "        try_get_answer_contexts. When no attempt succeeds, both carry the placeholder\n",
    "        text 'failed. please retry.'.\n",
    "    \"\"\"\n",
    "    failure_text = 'failed. please retry.'\n",
    "    res, contexts = failure_text, [failure_text]\n",
    "    try:\n",
    "        for _attempt in range(retry_num):\n",
    "            try:\n",
    "                contexts, res = try_get_answer_contexts(assistant_id, question, timeout_seconds)\n",
    "            except OpenAITimeoutException:\n",
    "                # Only timeouts are retried; the loop simply moves on to the next attempt.\n",
    "                print('OpenAI retrieving answer Timeout, retry...')\n",
    "            else:\n",
    "                break\n",
    "    except Exception as err:\n",
    "        # Any other error ends the attempts; it is reported and the placeholders are returned.\n",
    "        print(err)\n",
    "    return res, contexts"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Build assistant and upload knowledge files."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/648 [03:45<?, ?it/s]\n",
      "\n",
      "KeyboardInterrupt\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Upload the knowledge file; the context manager closes the local handle afterwards.\n",
    "with open(fiqa_path, \"rb\") as knowledge_file:\n",
    "    file = client.files.create(\n",
    "        file=knowledge_file,\n",
    "        purpose='assistants'\n",
    "    )\n",
    "\n",
    "# Add the file to the assistant\n",
    "assistant = client.beta.assistants.create(\n",
    "    instructions=\"You are a customer support chatbot. You must use your retrieval tool to retrieve relevant knowledge to best respond to customer queries.\",\n",
    "    model=\"gpt-4-1106-preview\",\n",
    "    tools=[{\"type\": \"retrieval\"}],\n",
    "    file_ids=[file.id]\n",
    ")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 3. Start Ragas Evaluation\n",
    "\n",
    "Note that this step consumes a large number of OpenAI API tokens: every question asked and every evaluation sends requests to the OpenAI service. Please keep an eye on your token consumption. If you only want to run a small number of tests, you can modify the code to reduce the test size."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Start from empty lists so re-running this cell never appends duplicates, which would\n",
    "# leave the answer/context lists longer than the question list.\n",
    "answer_list = []\n",
    "contexts_list = []\n",
    "\n",
    "# One assistant request per question; failures are recorded as placeholder strings.\n",
    "for question in tqdm(question_list):\n",
    "    answer, contexts = get_answer_contexts_from_assistant(question, assistant.id)\n",
    "    answer_list.append(answer)\n",
    "    contexts_list.append(contexts)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Choose the metrics you want to evaluate.\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from ragas import evaluate\n",
    "from ragas.metrics import answer_relevancy, faithfulness, context_recall, context_precision, answer_similarity\n",
    "\n",
    "# One row per question: retrieved contexts, generated answer and reference passages.\n",
    "ds = Dataset.from_dict(\n",
    "    dict(\n",
    "        question=question_list,\n",
    "        contexts=contexts_list,\n",
    "        answer=answer_list,\n",
    "        ground_truths=ground_truth_list,\n",
    "    )\n",
    ")\n",
    "\n",
    "# Uncomment the metrics you want to compute; each enabled metric adds OpenAI requests.\n",
    "# answer_correctness is not imported above - add it to the import before enabling it.\n",
    "result = evaluate(\n",
    "    ds,\n",
    "    metrics=[\n",
    "        context_precision,\n",
    "        # context_recall,\n",
    "        # faithfulness,\n",
    "        # answer_relevancy,\n",
    "        # answer_similarity,\n",
    "        # answer_correctness,\n",
    "    ],\n",
    ")\n",
    "print(result)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}